// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu

// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c

#include <torch/torch.h>

#include <cmath>
#include <vector>

// Forward launcher for deformable PSROI pooling.
//
// Only the prototype lives in this file; the definition is provided by another
// translation unit (presumably the CUDA kernel file -- confirm in the build).
// It is invoked by deform_psroi_pooling_cuda_forward below, which fills the
// arguments as follows:
//   data            - the wrapper's `input` tensor (checked contiguous there)
//   bbox, trans     - passed through unchanged
//   out, top_count  - non-const tensors passed through unchanged; presumably
//                     written by the implementation -- confirm in the definition
//   batch, channels, height, width - input.size(0) .. input.size(3)
//   num_bbox        - bbox.size(0); the wrapper requires it to equal out.size(0)
//   channels_trans  - 2 when no_trans is non-zero, otherwise trans.size(1)
//   no_trans, spatial_scale, output_dim, group_size, pooled_size, part_size,
//   sample_per_part, trans_std - forwarded verbatim from the Python-facing call
void DeformablePSROIPoolForward(const at::Tensor data,
                                const at::Tensor bbox,
                                const at::Tensor trans,
                                at::Tensor out,
                                at::Tensor top_count,
                                const int batch,
                                const int channels,
                                const int height,
                                const int width,
                                const int num_bbox,
                                const int channels_trans,
                                const int no_trans,
                                const float spatial_scale,
                                const int output_dim,
                                const int group_size,
                                const int pooled_size,
                                const int part_size,
                                const int sample_per_part,
                                const float trans_std);

// Backward launcher for deformable PSROI pooling.
//
// Only the prototype lives in this file; the definition is provided by another
// translation unit (presumably the CUDA kernel file -- confirm in the build).
// It is invoked by deform_psroi_pooling_cuda_backward below, which fills the
// arguments as follows:
//   out_grad        - gradient tensor (checked contiguous by the wrapper)
//   data            - the wrapper's `input` tensor (checked contiguous there)
//   bbox, trans, top_count - passed through unchanged
//   in_grad, trans_grad    - the wrapper's `input_grad` / `trans_grad`; non-const,
//                            presumably written (or accumulated into) by the
//                            implementation -- confirm in the definition
//   batch, channels, height, width - input.size(0) .. input.size(3)
//   num_bbox        - bbox.size(0); the wrapper requires it to equal
//                     out_grad.size(0)
//   channels_trans  - 2 when no_trans is non-zero, otherwise trans.size(1)
//   no_trans, spatial_scale, output_dim, group_size, pooled_size, part_size,
//   sample_per_part, trans_std - forwarded verbatim from the Python-facing call
void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
                                    const at::Tensor data,
                                    const at::Tensor bbox,
                                    const at::Tensor trans,
                                    const at::Tensor top_count,
                                    at::Tensor in_grad,
                                    at::Tensor trans_grad,
                                    const int batch,
                                    const int channels,
                                    const int height,
                                    const int width,
                                    const int num_bbox,
                                    const int channels_trans,
                                    const int no_trans,
                                    const float spatial_scale,
                                    const int output_dim,
                                    const int group_size,
                                    const int pooled_size,
                                    const int part_size,
                                    const int sample_per_part,
                                    const float trans_std);

// Python-facing entry point for the deformable PSROI pooling forward pass.
//
// Reads the shape arguments off the tensors, validates them, and hands
// everything to DeformablePSROIPoolForward (declared above, defined elsewhere).
//
// Parameters:
//   input      - must be contiguous; sizes 0..3 are forwarded as
//                batch, channels, height, width
//   bbox       - bbox.size(0) is the number of boxes and must equal out.size(0)
//   trans      - trans.size(1) is read only when no_trans is zero
//   out, top_count - forwarded unchanged to the launcher
//   no_trans   - when non-zero, channels_trans is fixed to 2 and `trans` is
//                not inspected here
//   remaining scalars are forwarded verbatim.
//
// Errors:
//   AT_CHECK fails if `input` is not contiguous.
//   AT_ERROR is raised if bbox.size(0) != out.size(0).
//
// NOTE(review): only `input` is checked for contiguity; whether bbox/trans/out
// also need to be contiguous depends on the launcher -- confirm there.
void deform_psroi_pooling_cuda_forward(at::Tensor input, at::Tensor bbox,
                                       at::Tensor trans,
                                       at::Tensor out, at::Tensor top_count,
                                       const int no_trans,
                                       const float spatial_scale,
                                       const int output_dim,
                                       const int group_size,
                                       const int pooled_size,
                                       const int part_size,
                                       const int sample_per_part,
                                       const float trans_std)
{
    AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");

    const int batch = input.size(0);
    const int channels = input.size(1);
    const int height = input.size(2);
    const int width = input.size(3);
    // Without a learned offset tensor the launcher is still told "2 channels".
    const int channels_trans = no_trans ? 2 : trans.size(1);

    const int num_bbox = bbox.size(0);
    if (num_bbox != out.size(0))
    {
        // AT_ERROR concatenates its arguments by streaming them; it does not
        // interpret printf-style "%d" placeholders, so build the message from
        // pieces instead of a format string.
        AT_ERROR("Output shape and bbox number wont match: (",
                 out.size(0), " vs ", num_bbox, ").");
    }

    DeformablePSROIPoolForward(input, bbox, trans, out, top_count,
                               batch, channels, height, width,
                               num_bbox,
                               channels_trans,
                               no_trans,
                               spatial_scale,
                               output_dim,
                               group_size,
                               pooled_size,
                               part_size,
                               sample_per_part,
                               trans_std);
}

// Python-facing entry point for the deformable PSROI pooling backward pass.
//
// Reads the shape arguments off the tensors, validates them, and hands
// everything to DeformablePSROIPoolBackwardAcc (declared above, defined
// elsewhere).
//
// Parameters:
//   out_grad   - must be contiguous; out_grad.size(0) must equal bbox.size(0)
//   input      - must be contiguous; sizes 0..3 are forwarded as
//                batch, channels, height, width
//   bbox       - bbox.size(0) is the number of boxes
//   trans      - trans.size(1) is read only when no_trans is zero
//   top_count, input_grad, trans_grad - forwarded unchanged to the launcher
//   no_trans   - when non-zero, channels_trans is fixed to 2 and `trans` is
//                not inspected here
//   remaining scalars are forwarded verbatim.
//
// Errors:
//   AT_CHECK fails if `out_grad` or `input` is not contiguous.
//   AT_ERROR is raised if bbox.size(0) != out_grad.size(0).
//
// NOTE(review): the other tensors are not checked for contiguity here; whether
// they need to be depends on the launcher -- confirm there.
void deform_psroi_pooling_cuda_backward(at::Tensor out_grad,
                                        at::Tensor input, at::Tensor bbox,
                                        at::Tensor trans, at::Tensor top_count,
                                        at::Tensor input_grad, at::Tensor trans_grad,
                                        const int no_trans,
                                        const float spatial_scale,
                                        const int output_dim,
                                        const int group_size,
                                        const int pooled_size,
                                        const int part_size,
                                        const int sample_per_part,
                                        const float trans_std)
{
    AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
    AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");

    const int batch = input.size(0);
    const int channels = input.size(1);
    const int height = input.size(2);
    const int width = input.size(3);
    // Without a learned offset tensor the launcher is still told "2 channels".
    const int channels_trans = no_trans ? 2 : trans.size(1);

    const int num_bbox = bbox.size(0);
    if (num_bbox != out_grad.size(0))
    {
        // AT_ERROR concatenates its arguments by streaming them; it does not
        // interpret printf-style "%d" placeholders, so build the message from
        // pieces instead of a format string.
        AT_ERROR("Output shape and bbox number wont match: (",
                 out_grad.size(0), " vs ", num_bbox, ").");
    }

    DeformablePSROIPoolBackwardAcc(out_grad,
                                   input,
                                   bbox,
                                   trans,
                                   top_count,
                                   input_grad,
                                   trans_grad,
                                   batch, channels, height, width, num_bbox,
                                   channels_trans,
                                   no_trans,
                                   spatial_scale,
                                   output_dim,
                                   group_size,
                                   pooled_size,
                                   part_size,
                                   sample_per_part,
                                   trans_std);
}

// Python bindings: registers the forward and backward wrappers defined in this
// file on the extension module, under the same names as the C++ functions.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // def() returns the module, so both registrations are chained.
    m.def("deform_psroi_pooling_cuda_forward",
          &deform_psroi_pooling_cuda_forward,
          "deform psroi pooling forward(CUDA)")
        .def("deform_psroi_pooling_cuda_backward",
             &deform_psroi_pooling_cuda_backward,
             "deform psroi pooling backward(CUDA)");
}